#!/bin/bash

# This file is documentation for how to get started with DeepSeek v3 on v5p-256.

# The flow of this file is as follows:
# 1. Convert the checkpoint downloaded from HuggingFace to make it compatible with MaxText.
# 2. Convert the same HuggingFace checkpoint into the unscanned checkpoint format, which allows more efficient decoding.
# 3. Run pre-training, fine-tuning, supervised fine-tuning, and decoding.

# Abort on the first failing command (errexit) and echo every command as it
# runs (xtrace).
set -o errexit -o xtrace

# Minute-resolution timestamp; the checkpoint paths built later in this script
# embed it so that each run writes to its own subdirectory.
idx="$(date '+%Y-%m-%d-%H-%M')"

# Model identifier and HuggingFace tokenizer id handed to the MaxText commands.
export MODEL_NAME="deepseek3-671b"
export TOKENIZER_PATH="deepseek-ai/DeepSeek-V3"

# TODO(ranran): add forward_pass_logit_checker.py test
# Installing torch for deps in forward_pass_logit_checker.py
# python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu

# Step 1:
# After downloading the checkpoints, copy them to a GCS bucket at $CHKPT_BUCKET.
# Non-Googlers please remember to use separate GCS paths for uploading model weights from HuggingFace ($CHKPT_BUCKET) and MaxText compatible weights ($MODEL_BUCKET).
# Non-Googlers please remember to point these variables to GCS buckets that you own, this script uses internal buckets for testing.
# You can use the HuggingFace checkpoint at https://huggingface.co/deepseek-ai/DeepSeek-V3
# CHKPT_BUCKET holds the weights downloaded from HuggingFace; MODEL_BUCKET is
# the root under which the MaxText-compatible weights are written.
export CHKPT_BUCKET='gs://maxtext-deepseek/deepseek3-671b/hf'
export MODEL_BUCKET='gs://maxtext-deepseek/deepseek3-671b'
# Convert the HuggingFace checkpoint and write the result under
# ${MODEL_BUCKET}/${idx}. JAX_PLATFORMS=cpu is set for this command only.
# Every expansion is quoted so that a user-supplied path is always passed as
# exactly one argument.
JAX_PLATFORMS=cpu python3 -m MaxText.utils.ckpt_scripts.convert_deepseek_family_ckpt \
  --base_model_path "${CHKPT_BUCKET}" \
  --maxtext_model_path "${MODEL_BUCKET}/${idx}" \
  --model_size "${MODEL_NAME}"

# Step 2:
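# Note that the checkpoint written by Step 1 (referred to below as `CONVERTED_CHECKPOINT`) is in a `scanned` format, which is great for training; for efficient decoding performance we want the checkpoint in an `unscanned` format, so the HuggingFace checkpoint is converted a second time here.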
# Reads the HuggingFace weights from ${CHKPT_BUCKET} and writes the result to
# ${MODEL_BUCKET}/${idx}/unscanned. JAX_PLATFORMS=cpu is set for this command
# only. Expansions are quoted so each path is passed as a single argument.
JAX_PLATFORMS=cpu python3 -m MaxText.utils.ckpt_scripts.convert_deepseek_family_unscanned_ckpt \
  --base_model_path "${CHKPT_BUCKET}" \
  --maxtext_model_path "${MODEL_BUCKET}/${idx}/unscanned" \
  --model_size "${MODEL_NAME}"

# Step 3:
# Non-Googlers: set `DATASET_PATH` to the GCS bucket that holds your training data.
export DATASET_PATH='gs://maxtext-dataset'
# Non-Googlers: set `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you own; it stores all the files generated by MaxText during a run.
export BASE_OUTPUT_DIRECTORY='gs://runner-maxtext-logs'
# Checkpoint subdirectories for this run, named once so the training and
# decoding commands can refer to them directly: `CONVERTED_CHECKPOINT` sits
# under ${MODEL_BUCKET}/${idx} and `UNSCANNED_CKPT_PATH` under its `unscanned`
# subdirectory.
export \
  CONVERTED_CHECKPOINT="${MODEL_BUCKET}/${idx}/0/items" \
  UNSCANNED_CKPT_PATH="${MODEL_BUCKET}/${idx}/unscanned/0/items"

# Run pre-training - matmul implementation
# Five steps on synthetic data with checkpointing disabled. The config file is
# looked up under MAXTEXT_PKG_DIR if set, otherwise under
# ${MAXTEXT_REPO_ROOT}/src/MaxText, otherwise under ${PWD}/src/MaxText.
# Overrides are listed one per line and every expansion is quoted so each
# key=value pair is passed as exactly one argument.
python3 -m MaxText.train \
  "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/configs/base.yml" \
  "base_output_directory=${BASE_OUTPUT_DIRECTORY}" \
  run_name=matmul_pre_training \
  per_device_batch_size=4 \
  enable_checkpointing=false \
  "model_name=${MODEL_NAME}" \
  ici_fsdp_parallelism=128 \
  steps=5 \
  max_target_length=1024 \
  async_checkpointing=false \
  tokenizer_type=huggingface \
  "tokenizer_path=${TOKENIZER_PATH}" \
  attention=flash \
  dtype=bfloat16 \
  weight_dtype=bfloat16 \
  megablox=False \
  sparse_matmul=False \
  dataset_type=synthetic
# Run fine-tuning - matmul implementation
# Five steps starting from the parameters at ${CONVERTED_CHECKPOINT}, reading
# data from ${DATASET_PATH}. The config file is looked up under
# MAXTEXT_PKG_DIR if set, otherwise under ${MAXTEXT_REPO_ROOT}/src/MaxText,
# otherwise under ${PWD}/src/MaxText.
# NOTE(review): this command used to pass enable_checkpointing twice (false,
# then true). Only the trailing `true` is kept, assuming the last override is
# the one that takes effect - confirm against MaxText's config parsing.
python3 -m MaxText.train \
  "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/configs/base.yml" \
  "base_output_directory=${BASE_OUTPUT_DIRECTORY}" \
  "dataset_path=${DATASET_PATH}" \
  "load_parameters_path=${CONVERTED_CHECKPOINT}" \
  run_name=matmul_fine_tuning \
  per_device_batch_size=4 \
  "model_name=${MODEL_NAME}" \
  ici_fsdp_parallelism=128 \
  steps=5 \
  max_target_length=1024 \
  async_checkpointing=false \
  tokenizer_type=huggingface \
  "tokenizer_path=${TOKENIZER_PATH}" \
  attention=flash \
  dtype=bfloat16 \
  weight_dtype=bfloat16 \
  megablox=False \
  sparse_matmul=False \
  enable_checkpointing=true
# Run supervised fine-tuning - matmul implementation
# Five steps of MaxText.sft_trainer with the sft.yml config, starting from the
# parameters at ${CONVERTED_CHECKPOINT} and using dataset_type=hf. The config
# file is looked up under MAXTEXT_PKG_DIR if set, otherwise under
# ${MAXTEXT_REPO_ROOT}/src/MaxText, otherwise under ${PWD}/src/MaxText.
# NOTE(review): this command used to pass enable_checkpointing twice (false,
# then true). Only the trailing `true` is kept, assuming the last override is
# the one that takes effect - confirm against MaxText's config parsing.
python3 -m MaxText.sft_trainer \
  "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/configs/sft.yml" \
  "base_output_directory=${BASE_OUTPUT_DIRECTORY}" \
  "load_parameters_path=${CONVERTED_CHECKPOINT}" \
  run_name=matmul_supervised_fine_tuning \
  per_device_batch_size=4 \
  "model_name=${MODEL_NAME}" \
  steps=5 \
  max_target_length=1024 \
  async_checkpointing=false \
  tokenizer_type=huggingface \
  "tokenizer_path=${TOKENIZER_PATH}" \
  attention=flash \
  dtype=bfloat16 \
  weight_dtype=bfloat16 \
  megablox=False \
  sparse_matmul=False \
  enable_checkpointing=true \
  ici_expert_parallelism=128 \
  ici_fsdp_parallelism=1 \
  dataset_type=hf
# Run decoding - matmul implementation
# Decodes the prompt "I love to" using the parameters at
# ${UNSCANNED_CKPT_PATH}. The config file is looked up under MAXTEXT_PKG_DIR if
# set, otherwise under ${MAXTEXT_REPO_ROOT}/src/MaxText, otherwise under
# ${PWD}/src/MaxText.
# NOTE(review): scan_layers=False presumably matches the unscanned checkpoint
# layout - confirm.
# Every expansion is quoted so each key=value pair is passed as one argument.
python3 -m MaxText.decode \
  "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/configs/base.yml" \
  "base_output_directory=${BASE_OUTPUT_DIRECTORY}" \
  "load_parameters_path=${UNSCANNED_CKPT_PATH}" \
  run_name=decode \
  per_device_batch_size=1 \
  enable_checkpointing=false \
  "model_name=${MODEL_NAME}" \
  max_prefill_predict_length=100 \
  max_target_length=1024 \
  tokenizer_type=huggingface \
  "tokenizer_path=${TOKENIZER_PATH}" \
  attention=flash \
  dtype=bfloat16 \
  weight_dtype=bfloat16 \
  megablox=False \
  sparse_matmul=False \
  ici_tensor_parallelism=128 \
  ici_fsdp_parallelism=1 \
  prompt="I love to" \
  scan_layers=False
